iT邦幫忙

3

計算 LLM 的 Perplexity 困惑度

  • 分享至 

  • xImage
  •  

簡介

困惑度 (Perplexity, PPL) 是個評估語言模型相當實用的指標,用來表示語言模型對一句話的困惑程度。什麼叫困惑程度呢?當我們看到一句話會產生困惑時,代表這句話可能:

  1. 包含不合理的資訊。
  2. 沒頭沒尾含糊不清。
  3. 使用的文法構句不正常。
  4. 邏輯上互斥矛盾。
  5. 看到根本沒學過的語言!

除此之外,還有很多情況會令人感到困惑。對於人類而言是如此,對於語言模型而言亦是如此。一個 Decoder LLM 會透過 Autoregressive 的機制不斷生成字詞來構成語句,在這個過程中,LLM 會不斷生成一份機率表,根據這份機率表來取樣決定下個字詞是什麼。

如果這個機率表顯示「下個字是 X 的機率為 100% 🤗」那就代表 LLM 非常明確,這時 PPL 就會相對較低。但如果 LLM 覺得「嗯…好像每個字都有可能 🤔」那就代表 LLM 相當困惑,這時 PPL 就會較高。

實做

一般 LLM 都是以交叉熵 (Cross Entropy) 當作損失函數 (Loss Function),而 PPL 就是對損失值 (Loss)指數函數 (Exponential Function) 的結果。那實際上到底要如何計算 PPL 呢?首先,要先來瞭解如何取得單次推論時的 Loss,在 HF Transformers 裡面可以這樣做:

model = AutoModelForCausalLM.from_pretrained(...)
tk = AutoTokenizer.from_pretrained(...)
input_ids = tk.encode(...)

outputs = model.forward(input_ids=input_ids, labels=input_ids)
print(outputs.loss)

只需要多指定一個 labels 參數,並且代入原本的 input_ids 就能得到 Loss 了,接下來對 Loss 取指數函數就能得到 PPL:

ppl = torch.exp(outputs.loss)

最簡單的評測方法,是固定序列長度來計算 PPL,以 Wikitext 資料集為例,首先透過 Tokenizer 進行分詞:

tk = AutoTokenizer.from_pretrained(...)
dataset = load_dataset(
    "wikitext",
    "wikitext-2-raw-v1",
    split="test",
)

input_ids = list()
for item in dataset:
    text = item["text"] + "\n"
    tokens = tk.encode(text, add_special_tokens=False)
    input_ids.extend(tokens)

這裡先設定序列長度為 2048 來進行評估:

import torch

seqlen = 2048
data_size = len(input_ids) // seqlen  # 計算序列數量
input_ids = input_ids[: data_size * seqlen]  # 捨棄最後一筆

其中長度不足 2048 的最後一筆資料會被捨棄,然後將 input_ids 轉為 Tensor,並在每個序列的開頭加上 BOS Token,最後把 Tensor 移動到對應的裝置上:

input_ids = torch.LongTensor(input_ids).view(data_size, seqlen)
bos_token = torch.full(
    (data_size, 1),
    tk.bos_token_id,
    dtype=torch.int64,
)
input_ids = torch.concat((bos_token, input_ids), dim=1)
input_ids = input_ids.to(model.device)

接下來開始對測試資料進行推論:

nlls = list()
for i in range(data_size):
    batch = input_ids[i : i + 1]
    outputs = model.forward(batch, labels=batch)
    nlls.append(outputs.loss)
ppl = torch.exp(torch.stack(nlls).mean())
print(ppl)

nlls 代表 Negative Log-Likelihood,與 Cross Entropy 基本上是等價的概念。因為評估的過程可能會花上一段時間,所以可以借助 tqdm 套件來顯示評估進度:

from tqdm import trange

nlls = list()
batch_size = 1
with trange(0, data_size, batch_size) as prog:
    for i in prog:
        batch = input_ids[i : i + batch_size]
        outputs = model.forward(batch, labels=batch)
        nlls.append(outputs.loss)
        ppl = torch.exp(torch.stack(nlls).mean())
        prog.desc = f"ppl: {ppl:.4f}"

測試

筆者實測 TinyLlama 1B Chat 的 PPL 在序列長度 2K 時為 8.0233,而序列長度 1K 時則為 8.9707,由此可見序列長度對 PPL 也有影響,可以想像比較短的序列,上下文的資訊比較不足,就很像看到沒頭沒尾的一句話,自然也會感到比較困惑一些。

一般而言,測試的序列長度在模型的 Context Window 之內時,序列越長 PPL 越低,但如果序列長度超過訓練長度太多時,困惑度就會開始無情暴漲。例如 TinyLlama 在訓練時的 Context Window 設定為 2048,這時拿 4096 的序列來測試他的 PPL,就會得到 139.8860 這樣非常高的結果。

場景

PPL 評估相較於其他 LLM 的評估指標而言相對簡單許多,但 PPL 本身並不代表任何實際概念上的準確率之類的,所以在比較上需要特別注意。某些情況下特別適合用 PPL 做比較:

對同一個模型,做不同方法、層級的量化。

例如用 Llama3 8B 來比較 GPTQ 與 AWQ 之間的 PPL 好壞,或者比較 HQQ 8-Bit 與 HQQ 4-Bit 之間的 PPL 差異。

用同樣的訓練資料,比較不同架構、參數量的模型。

例如比較 LLaMA 7B, 13B, 33B, 65B 等不同參數量的模型,因為他們使用的訓練資料大致相同,這樣的比較就是有意義的。

但如果訓練資料不同,比較 PPL 的意義就相對小一點。例如 Llama3 在 Wikitext 上的 PPL 就比 Llama2 高一點 (6.5 vs 5.5),但實際使用上 Llama3 的效果是比 Llama2 好的多。

此外,當一個模型的 Vocab Size 很大時,其 Logits 的分佈會更零散,因此計算出來的 PPL 也會相對較高。所以像 Llama 3 的 Vocab Size 高達 128256,相較於 Llama 2 的 32000 而言,用相同資料集評估出來的 PPL 在 Llama 3 這邊可能反而比較高,然而 Llama 3 模型本身效能其實是比較好的。

擴展 Context Window 的實驗。

例如嘗試把 TinyLlama 的 Context Window 從 2K 擴展到 4K 或 8K 時,透過 PPL 來評估不同長度、不同擴展方法的效果就相當方便。

做選擇題的評估。

這是一個滿常見的做法,諸如 MMLU 等評估資料集,大多是給一個題目與多個選項,讓模型去選擇正確的選項。除了嘗試讓模型直接回答出該選項以外,其中一種做法就是用 PPL 來看模型更偏好哪個答案:

choices = [
    "A banana is red.",
    "A banana is yellow.",
    "A banana is blue.",
    "A banana is green.",
]

for text in choices:
    input_ids = tk.encode(text, return_tensors="pt")
    outputs = model.forward(input_ids=input_ids, labels=input_ids)
    print(outputs.loss, text)

雖然說是比較 PPL,但其實只要看誰的 Loss 比較低就好,畢竟兩者是正相關的,最後輸出結果如下:

5.9140 A banana is red.
4.9455 A banana is yellow.
5.8184 A banana is blue.
5.7485 A banana is green.

可以看到,黃色的香蕉 Loss 最低,所以評估上會判定模型傾向於採納這個選項。但因為這是自動評估的關係,所以實際上模型是不是真的這樣認為又是另外一回事了。


圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言